/**
 * Copyright (c) 2021-2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "unit_ecma_test.h"
#include "optimizer/ir/datatype.h"
#include "optimizer/ir/graph_cloner.h"
#include "optimizer/optimizations/cleanup.h"
#include "optimizer/optimizations/branch_elimination.h"

namespace ark::compiler {
/// Fixture for BranchElimination tests on dynamic-language graphs.
/// It contributes no state of its own; every graph-building helper used by the
/// tests below (CreateGraphDynStubWithDefaultRuntime, GRAPH/INST DSL) comes from AsmTest.
class IrBranchEliminationTest : public AsmTest {};

// NOLINTBEGIN(readability-magic-numbers)
// BB 4 is reached only when the dominating "is INT" check in BB 2 succeeded, so the
// "is DOUBLE" check in BB 4 is known to fail: its IfImm folds towards BB 6 and BB 5 dies.
TEST_F(IrBranchEliminationTest, EliminateCompareAnyTypeVNDiffTypesTrue)
{
    auto srcGraph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(srcGraph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(8, 2).s32();
        PARAMETER(10, 3).s32();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, 6, 5)
        {
            INST(6, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_DOUBLE_TYPE).Inputs(0);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(6);
        }

        BASIC_BLOCK(5, -1)
        {
            INST(9, Opcode::Return).s32().Inputs(8);
        }

        BASIC_BLOCK(6, -1)
        {
            INST(11, Opcode::Return).s32().Inputs(10);
        }
    }

    // Both passes are expected to report a change.
    ASSERT_TRUE(srcGraph->RunPass<BranchElimination>());
    ASSERT_TRUE(srcGraph->RunPass<Cleanup>());
    GraphChecker(srcGraph).Check();

    // Expected: BB 2 branches directly to BB 3 / BB 6; parameter 2 (only used in BB 5) is gone.
    auto expectedGraph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(expectedGraph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(10, 3).s32();

        BASIC_BLOCK(2, 3, 6)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(6, -1)
        {
            INST(11, Opcode::Return).s32().Inputs(10);
        }
    }

    EXPECT_TRUE(GraphComparator().Compare(srcGraph, expectedGraph));
}

// Negative case: BB 4 is entered when the "is INT" check failed, which says nothing about
// whether the value is a DOUBLE, so neither pass may change the graph.
TEST_F(IrBranchEliminationTest, EliminateCompareAnyTypeVNDiffTypesFalse1)
{
    auto srcGraph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(srcGraph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(8, 2).s32();
        PARAMETER(10, 3).s32();

        BASIC_BLOCK(2, 4, 3)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, 6, 5)
        {
            INST(6, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_DOUBLE_TYPE).Inputs(0);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(6);
        }

        BASIC_BLOCK(5, -1)
        {
            INST(9, Opcode::Return).s32().Inputs(8);
        }

        BASIC_BLOCK(6, -1)
        {
            INST(11, Opcode::Return).s32().Inputs(10);
        }
    }

    // Snapshot the untouched input before running any pass; it is the expected result.
    auto expectedGraph = GraphCloner(srcGraph, srcGraph->GetAllocator(), srcGraph->GetLocalAllocator()).CloneGraph();

    ASSERT_FALSE(srcGraph->RunPass<BranchElimination>());
    ASSERT_FALSE(srcGraph->RunPass<Cleanup>());
    GraphChecker(srcGraph).Check();

    EXPECT_TRUE(GraphComparator().Compare(srcGraph, expectedGraph));
}

// Negative case with the inverted condition (CC_NE): BB 4 is again the "not INT" path,
// so the "is DOUBLE" check there cannot be decided and the graph must stay unchanged.
TEST_F(IrBranchEliminationTest, EliminateCompareAnyTypeVNDiffTypesFalse2)
{
    auto srcGraph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(srcGraph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(8, 2).s32();
        PARAMETER(10, 3).s32();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, 6, 5)
        {
            INST(6, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_DOUBLE_TYPE).Inputs(0);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(6);
        }

        BASIC_BLOCK(5, -1)
        {
            INST(9, Opcode::Return).s32().Inputs(8);
        }

        BASIC_BLOCK(6, -1)
        {
            INST(11, Opcode::Return).s32().Inputs(10);
        }
    }

    // Snapshot the untouched input before running any pass; it is the expected result.
    auto expectedGraph = GraphCloner(srcGraph, srcGraph->GetAllocator(), srcGraph->GetLocalAllocator()).CloneGraph();

    ASSERT_FALSE(srcGraph->RunPass<BranchElimination>());
    ASSERT_FALSE(srcGraph->RunPass<Cleanup>());
    GraphChecker(srcGraph).Check();

    EXPECT_TRUE(GraphComparator().Compare(srcGraph, expectedGraph));
}

// BB 4 repeats the same "is INT" check on the same input and is reached only when BB 2's
// check succeeded, so it is known to succeed again: BB 4 folds towards BB 5 and BB 6 dies.
TEST_F(IrBranchEliminationTest, EliminateCompareAnyTypeVNEqTypesTrue1)
{
    auto srcGraph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(srcGraph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(8, 2).s32();
        PARAMETER(10, 3).s32();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, 6, 5)
        {
            INST(6, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(6);
        }

        BASIC_BLOCK(5, -1)
        {
            INST(9, Opcode::Return).s32().Inputs(8);
        }

        BASIC_BLOCK(6, -1)
        {
            INST(11, Opcode::Return).s32().Inputs(10);
        }
    }

    // Both passes are expected to report a change.
    ASSERT_TRUE(srcGraph->RunPass<BranchElimination>());
    ASSERT_TRUE(srcGraph->RunPass<Cleanup>());
    GraphChecker(srcGraph).Check();

    // Expected: BB 2 branches directly to BB 3 / BB 5; parameter 3 (only used in BB 6) is gone.
    auto expectedGraph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(expectedGraph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(8, 2).s32();

        BASIC_BLOCK(2, 3, 5)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(5, -1)
        {
            INST(9, Opcode::Return).s32().Inputs(8);
        }
    }

    EXPECT_TRUE(GraphComparator().Compare(srcGraph, expectedGraph));
}

// Mirror of EqTypesTrue1: BB 4 is reached only when the "is INT" check failed, so repeating
// the identical check there is known to fail too: BB 4 folds towards BB 6 and BB 5 dies.
TEST_F(IrBranchEliminationTest, EliminateCompareAnyTypeVNEqTypesTrue2)
{
    auto srcGraph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(srcGraph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(8, 2).s32();
        PARAMETER(10, 3).s32();

        BASIC_BLOCK(2, 4, 3)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, 6, 5)
        {
            INST(6, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(6);
        }

        BASIC_BLOCK(5, -1)
        {
            INST(9, Opcode::Return).s32().Inputs(8);
        }

        BASIC_BLOCK(6, -1)
        {
            INST(11, Opcode::Return).s32().Inputs(10);
        }
    }

    // Both passes are expected to report a change.
    ASSERT_TRUE(srcGraph->RunPass<BranchElimination>());
    ASSERT_TRUE(srcGraph->RunPass<Cleanup>());
    GraphChecker(srcGraph).Check();

    // Expected: BB 2 branches directly to BB 6 / BB 3; parameter 2 (only used in BB 5) is gone.
    auto expectedGraph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(expectedGraph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(10, 3).s32();

        BASIC_BLOCK(2, 6, 3)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(6, -1)
        {
            INST(11, Opcode::Return).s32().Inputs(10);
        }
    }

    EXPECT_TRUE(GraphComparator().Compare(srcGraph, expectedGraph));
}

// Same as EqTypesTrue1, but BB 4 branches on the very same CompareAnyType (INST 2)
// instead of a duplicate: its outcome is fixed on that path, so BB 6 dies.
TEST_F(IrBranchEliminationTest, EliminateCompareAnyTypeVNEqTypesTrue3)
{
    auto srcGraph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(srcGraph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(8, 2).s32();
        PARAMETER(10, 3).s32();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, 6, 5)
        {
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(5, -1)
        {
            INST(9, Opcode::Return).s32().Inputs(8);
        }

        BASIC_BLOCK(6, -1)
        {
            INST(11, Opcode::Return).s32().Inputs(10);
        }
    }

    // Both passes are expected to report a change.
    ASSERT_TRUE(srcGraph->RunPass<BranchElimination>());
    ASSERT_TRUE(srcGraph->RunPass<Cleanup>());
    GraphChecker(srcGraph).Check();

    // Expected: BB 2 branches directly to BB 3 / BB 5; parameter 3 (only used in BB 6) is gone.
    auto expectedGraph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(expectedGraph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(8, 2).s32();

        BASIC_BLOCK(2, 3, 5)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(5, -1)
        {
            INST(9, Opcode::Return).s32().Inputs(8);
        }
    }

    EXPECT_TRUE(GraphComparator().Compare(srcGraph, expectedGraph));
}

// STRING is a subtype of HEAP_OBJECT: BB 4 is reached only when the "is STRING" check in
// BB 2 succeeded, so the "is HEAP_OBJECT" check in BB 4 is known to succeed as well.
// BB 4 therefore folds towards its false successor BB 5, and BB 6 becomes dead.
TEST_F(IrBranchEliminationTest, EliminateCompareAnyTypeSubtype1)
{
    auto graph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(graph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(8, 2).s32();
        PARAMETER(10, 3).s32();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_STRING_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, 6, 5)
        {
            INST(6, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_HEAP_OBJECT_TYPE).Inputs(0);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(6);
        }

        BASIC_BLOCK(5, -1)
        {
            INST(9, Opcode::Return).s32().Inputs(8);
        }

        BASIC_BLOCK(6, -1)
        {
            INST(11, Opcode::Return).s32().Inputs(10);
        }
    }

    ASSERT_TRUE(graph->RunPass<BranchElimination>());
    ASSERT_TRUE(graph->RunPass<Cleanup>());

    GraphChecker(graph).Check();

    // Expected: BB 2 branches directly to BB 3 / BB 5. The surviving block is the original
    // BB 5 (INST 9 returning parameter 2), so it keeps those ids here. It was previously
    // mislabelled as BB 6 / INST 11, which contradicted the input graph and the sibling tests.
    auto graphOpt = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(graphOpt)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(8, 2).s32();

        BASIC_BLOCK(2, 3, 5)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_STRING_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(5, -1)
        {
            INST(9, Opcode::Return).s32().Inputs(8);
        }
    }

    EXPECT_TRUE(GraphComparator().Compare(graph, graphOpt));
}

// Negative subtype case: knowing the value is a HEAP_OBJECT does not prove it is a STRING
// (the relation only holds the other way), so neither pass may change the graph.
TEST_F(IrBranchEliminationTest, EliminateCompareAnyTypeSubtype2)
{
    auto srcGraph = CreateGraphDynStubWithDefaultRuntime();
    GRAPH(srcGraph)
    {
        PARAMETER(0, 0).any();
        PARAMETER(4, 1).s32();
        PARAMETER(8, 2).s32();
        PARAMETER(10, 3).s32();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(2, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_HEAP_OBJECT_TYPE).Inputs(0);
            INST(3, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(2);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(5, Opcode::Return).s32().Inputs(4);
        }

        BASIC_BLOCK(4, 6, 5)
        {
            INST(6, Opcode::CompareAnyType).b().AnyType(AnyBaseType::ECMASCRIPT_STRING_TYPE).Inputs(0);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(6);
        }

        BASIC_BLOCK(5, -1)
        {
            INST(9, Opcode::Return).s32().Inputs(8);
        }

        BASIC_BLOCK(6, -1)
        {
            INST(11, Opcode::Return).s32().Inputs(10);
        }
    }

    // Snapshot the untouched input before running any pass; it is the expected result.
    auto expectedGraph = GraphCloner(srcGraph, srcGraph->GetAllocator(), srcGraph->GetLocalAllocator()).CloneGraph();

    ASSERT_FALSE(srcGraph->RunPass<BranchElimination>());
    ASSERT_FALSE(srcGraph->RunPass<Cleanup>());
    GraphChecker(srcGraph).Check();

    EXPECT_TRUE(GraphComparator().Compare(srcGraph, expectedGraph));
}
// NOLINTEND(readability-magic-numbers)

}  // namespace ark::compiler
